module Hasura.RQL.DML.Count
  ( CountQueryP1 (..),
    validateCountQWith,
    validateCountQ,
    runCount,
    countQToTx,
  )
where

import Control.Monad.Trans.Control (MonadBaseControl)
import Data.Aeson
import Data.ByteString.Builder qualified as BB
import Data.Sequence qualified as DS
import Database.PG.Query qualified as Q
import Hasura.Backends.Postgres.SQL.DML qualified as S
import Hasura.Backends.Postgres.SQL.Types
import Hasura.Backends.Postgres.Translate.BoolExp
import Hasura.Base.Error
import Hasura.EncJSON
import Hasura.Prelude
import Hasura.RQL.DML.Internal
import Hasura.RQL.DML.Types
import Hasura.RQL.IR.BoolExp
import Hasura.RQL.Types
import Hasura.SQL.Types
import Hasura.Session
import Hasura.Tracing qualified as Tracing

-- | A count query after validation (built by 'validateCountQWith'), ready to
-- be turned into SQL by 'mkSQLCount'.
--
-- * 'cqp1Table': the table whose rows are counted.
-- * 'cqp1Where': the resolved select-permission filter, paired with the
--   converted user-supplied @where@ expression when the query carried one.
-- * 'cqp1Distinct': when present, the columns projected by the inner
--   @SELECT DISTINCT@, so that distinct column combinations are counted
--   rather than rows.
data CountQueryP1 = CountQueryP1
  { cqp1Table :: !QualifiedTable,
    cqp1Where :: !(AnnBoolExpSQL ('Postgres 'Vanilla), Maybe (AnnBoolExpSQL ('Postgres 'Vanilla))),
    cqp1Distinct :: !(Maybe [PGCol])
  }
  deriving (Eq)

-- | Translate a validated count query into a @SELECT count(*)@ over an inner
-- select aliased @r@.
--
-- The inner select reads from the target table and is filtered by the
-- permission filter, combined (via 'andAnnBoolExps') with the user's where
-- expression when there is one. It projects either the requested distinct
-- columns (@SELECT DISTINCT c1, .., cn@) or every column (@SELECT *@).
mkSQLCount ::
  CountQueryP1 -> S.Select
mkSQLCount (CountQueryP1 table (permFilter, mUserWhere) mDistinctCols) =
  S.mkSelect
    { S.selExtr = [S.Extractor S.countStar Nothing],
      S.selFrom = Just $ S.FromExp [S.mkSelFromExp False subquery $ TableName "r"]
    }
  where
    -- The permission filter always applies; the user's filter only narrows it.
    combinedFilter = case mUserWhere of
      Nothing -> permFilter
      Just userWhere -> andAnnBoolExps permFilter userWhere

    whereExp = toSQLBoolExp (S.QualTable table) combinedFilter

    subquery =
      projection
        { S.selFrom = Just $ S.mkSimpleFromExp table,
          S.selWhere = Just $ S.WhereFrag whereExp
        }

    projection = maybe selectAll selectDistinct mDistinctCols

    -- SELECT *
    selectAll =
      S.mkSelect
        { S.selExtr = [S.Extractor (S.SEStar Nothing) Nothing]
        }

    -- SELECT DISTINCT c1, .., cn
    selectDistinct cols =
      S.mkSelect
        { S.selDistinct = Just S.DistinctSimple,
          S.selExtr = [S.Extractor (S.mkSIdenExp col) Nothing | col <- cols]
        }

-- Shape of the SQL built by 'mkSQLCount', with and without distinct columns:
--   SELECT count(*) FROM (SELECT DISTINCT c1, .. cn FROM .. WHERE ..) r;
--   SELECT count(*) FROM (SELECT * FROM .. WHERE ..) r;
-- | Validate a 'CountQuery' and produce the corresponding 'CountQueryP1'.
--
-- Steps, in order:
--
-- 1. Look up the table's info and the role's select permission on it; a
--    failure of the permission lookup has a hint appended explaining that
--    @count@ needs @select@ permissions.
-- 2. If distinct columns were given, run 'checkSelOnCol' and
--    'assertColumnExists' over them (errors are reported under the
--    @distinct@ path key).
-- 3. Convert the optional where expression (errors are reported under the
--    @where@ path key), using the supplied value builder.
-- 4. Resolve the select permission's filter with the supplied session
--    variable builder.
validateCountQWith ::
  (UserInfoM m, QErrM m, TableInfoRM ('Postgres 'Vanilla) m) =>
  SessionVariableBuilder ('Postgres 'Vanilla) m ->
  (ColumnType ('Postgres 'Vanilla) -> Value -> m S.SQLExp) ->
  CountQuery ->
  m CountQueryP1
validateCountQWith sessVarBldr prepValBldr (CountQuery table _ mDistinctCols mUserWhere) = do
  tableInfo <- askTabInfoSource table

  -- Counting is gated on the role's select permission.
  selPerm <- modifyErr (<> selNecessaryMsg) $ askSelPermInfo tableInfo

  let fieldInfoMap = _tciFieldInfoMap $ _tiCoreInfo tableInfo
      distinctColChecks =
        [ checkSelOnCol selPerm,
          assertColumnExists fieldInfoMap relInDistColsErr
        ]
      convertWhere boolExp =
        withPathK "where" $
          convBoolExp fieldInfoMap selPerm boolExp sessVarBldr table (valueParserWithCollectableType prepValBldr)

  -- Only runs when the query asked for distinct columns.
  forM_ mDistinctCols $ withPathK "distinct" . verifyAsrns distinctColChecks

  -- The user's where clause is converted before the permission filter is
  -- resolved; keep this order, as both run in the same monad.
  userWhere <- traverse convertWhere mUserWhere
  permFilter <- convAnnBoolExpPartialSQL sessVarBldr $ spiFilter selPerm

  pure $ CountQueryP1 table (permFilter, userWhere) mDistinctCols
  where
    selNecessaryMsg =
      "; \"count\" is only allowed if the role "
        <> "has \"select\" permissions on the table"
    relInDistColsErr =
      "Relationships can't be used in \"distinct\"."

-- | Validate a 'CountQuery' against the table cache of the query's source.
--
-- Runs 'validateCountQWith' with 'sessVarFromCurrentSetting' and
-- 'binRHSBuilder' inside 'runDMLP1T', and returns the validated query
-- together with the sequence of prepared arguments collected on the way.
validateCountQ ::
  (QErrM m, UserInfoM m, CacheRM m) =>
  CountQuery ->
  m (CountQueryP1, DS.Seq Q.PrepArg)
validateCountQ query = do
  tableCache :: TableCache ('Postgres 'Vanilla) <- askTableCache source
  runTableCacheRT
    (runDMLP1T $ validateCountQWith sessVarFromCurrentSetting binRHSBuilder query)
    (source, tableCache)
  where
    source = cqSource query

-- | Execute a validated count query with its prepared arguments and encode
-- the single resulting count as the JSON object @{"count":N}@.
--
-- Database errors are mapped through 'dmlTxErrorHandler'.
countQToTx ::
  (QErrM m, MonadTx m) =>
  (CountQueryP1, DS.Seq Q.PrepArg) ->
  m EncJSON
countQToTx (countQuery, prepArgs) =
  encJFromBuilder . renderCount
    <$> liftTx (Q.rawQE dmlTxErrorHandler (Q.fromBuilder countSQL) (toList prepArgs) True)
  where
    countSQL = toSQL $ mkSQLCount countQuery

    -- The query returns exactly one row holding one integer.
    renderCount (Q.SingleRow (Identity rowCount)) =
      mconcat
        [ BB.byteString "{\"count\":",
          BB.intDec rowCount,
          BB.char7 '}'
        ]

-- | Entry point for a count query: validate it, then run the generated SQL
-- against the query's source.
--
-- The source configuration is looked up first; the transaction is run with
-- that source's execution context and 'Q.ReadOnly'.
runCount ::
  ( QErrM m,
    UserInfoM m,
    CacheRM m,
    MonadIO m,
    MonadBaseControl IO m,
    Tracing.MonadTrace m,
    MetadataM m
  ) =>
  CountQuery ->
  m EncJSON
runCount query = do
  sourceConfig <- askSourceConfig @('Postgres 'Vanilla) (cqSource query)
  validated <- validateCountQ query
  runTxWithCtx (_pscExecCtx sourceConfig) Q.ReadOnly $ countQToTx validated
